--- redirect_from: - "/03/deep-learning/aif-detection/perf-aif-lv-detection" interact_link: content/03/deep_learning/aif_detection/Perf_AIF_LV_detection.ipynb kernel_name: python3 kernel_path: content/03/deep_learning/aif_detection has_widgets: false title: |- Perfusion AIF LV detection pagenum: 2 prev_page: url: /01/abstract.html next_page: url: /03/deep_learning/analysis/PerfusionSeg.html suffix: .ipynb search: data image mr model hui xue train automated detection lv arterial input function aif series cardiac perfusion deployment scanner author nih gov load check loaded properly images masks construct loaders augmentation multi class trainer loss binary segmenation saving comment: "***PROGRAMMATICALLY GENERATED, DO NOT EDIT. SEE ORIGINAL FILES IN /content***" ---
Perfusion AIF LV detection

Automated detection of the left ventricle (LV) in arterial input function (AIF) image series for cardiac MR perfusion, with deployment of the model to the MR scanner

Author: Hui Xue <hui.xue@nih.gov>

#import os
#os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
#os.environ['CUDA_VISIBLE_DEVICES']='1,2'

from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torch.onnx
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as T
from torchvision.utils import *

import numpy as np
import collections
import matplotlib.pyplot as plt
from matplotlib import animation, rc
animation.rcParams['animation.writer'] = 'ffmpeg'
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

import scipy
import scipy as sp
from scipy.spatial import ConvexHull
from scipy.ndimage.morphology import binary_fill_holes

from collections import OrderedDict
import time
from tensorboardX import SummaryWriter

from skimage import io, transform
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision

from IPython.display import display, clear_output, HTML, Image

from PIL import Image
import imp
import os
import sys
import math
import time
import random
import shutil
import scipy.misc
from glob import glob
import sklearn
import logging
from tqdm.notebook import tqdm

%matplotlib inline
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
import training
import models
import utils
import utils.cmr_ml_utils_plotting
def show(img):
    """Print the numpy shape of a 3-D torch tensor and display it with axis 0 moved last.

    Presumably used for (C, H, W) images such as torchvision grids; the
    (1, 2, 0) transpose turns them into the (H, W, C) layout imshow expects.
    """
    as_array = img.numpy()
    print(as_array.shape)
    channels_last = np.transpose(as_array, (1, 2, 0))
    plt.imshow(channels_last, interpolation='nearest')

Load image data

import scipy.io

# Root data folder: every sub-directory is a location whose entries are case folders
# (this layout is what PerfAIFDataset walks).
img_dir = './data'

class PerfAIFDataset(Dataset):
    """Perfusion AIF dataset.

    Scans ``img_dir/<location>/<case>/`` and, for every case, loads the AIF image
    series ``aif_scc.npy`` (shape RO x E1 x N) together with its mask
    ``aif_masks_final.npy`` (falling back to ``aif_masks.npy``). Each series is
    padded/cropped to ``min_reps`` frames and ``W`` columns, divided by its
    maximum and stored as a float32 (min_reps, RO, W) array. Two masks are kept
    per case: the loaded mask as-is (``lv_rv``) and a binary mask of the pixels
    labelled 1 (``lv``), both shaped (1, RO, W).

    Items are ``(image, mask, name)`` tuples; ``which_mask`` ('lv' or 'lv_rv',
    matched case-insensitively) selects which mask is returned.
    """

    def __init__(self, img_dir, which_mask='LV_RV', min_reps=64, W=48, transform=None, RO=64):
        """
        Args:
            img_dir (string): Directory with all the images.
            which_mask (string): 'LV' or 'LV_RV' (case-insensitive).
            min_reps (int): number of time frames every series is padded/cropped to.
            W (int): number of columns (E1) every series is padded/cropped to.
            transform (callable, optional): Optional transform to be applied on a sample.
            RO (int): required number of rows; cases with a different row count are skipped.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.which_mask = which_mask
        self.min_reps = min_reps
        self.W = W
        self.RO = RO

        # every sub-directory of img_dir is a "location"; its entries are cases
        locations = [loc for loc in os.listdir(self.img_dir)
                     if os.path.isdir(os.path.join(self.img_dir, loc))]
        num_samples = sum(len(os.listdir(os.path.join(self.img_dir, loc))) for loc in locations)
        print("Found %d cases ... " % num_samples)

        self.initialize_storage()

        t0 = time.time()
        print("Start loading cases ... ")

        total_num_loaded = 0
        for loc in locations:
            total_num_loaded = self.load_one_loc(loc, total_num_loaded, t0)

        print("Total samples loaded %d " % total_num_loaded)

    def initialize_storage(self):
        """Reset the per-sample storage lists."""
        self.aif = []
        self.lv_rv_masks = []
        self.lv_masks = []
        self.names = []

    def load_one_data(self, loc, f_prefix):
        """Load ``img_dir/loc/<f_prefix>.npy`` and return the array."""
        f_name = f_prefix + '.npy'
        return np.load(os.path.join(self.img_dir, loc, f_name))

    @staticmethod
    def _make_progress_bar(total):
        """Create a progress bar, falling back to text tqdm when the notebook widget fails."""
        try:
            return tqdm(total=total, file=sys.stdout)
        except (ImportError, ValueError):
            # tqdm.notebook needs a working ipywidgets front end; a missing or
            # mismatched ipywidgets raises here (e.g. "Widget is not in list").
            from tqdm import tqdm as text_tqdm
            return text_tqdm(total=total, file=sys.stdout)

    def _fit_to_shape(self, Gd, lv_rv):
        """Pad/crop an (RO, E1, N) series to min_reps frames and W columns.

        Missing frames repeat the last frame; the E1 axis is center-cropped or
        center zero-padded, and the (RO, E1) mask is cropped/padded identically.
        Returns the adjusted ``(Gd, lv_rv)``.
        """
        RO, E1, N = Gd.shape

        if N < self.min_reps:
            new_Gd = np.zeros((RO, E1, self.min_reps))
            new_Gd[:, :, 0:N] = Gd
            # broadcast the last acquired frame into every missing repetition
            new_Gd[:, :, N:self.min_reps] = Gd[:, :, N - 1:N]
            Gd = new_Gd

        if N > self.min_reps:
            Gd = Gd[:, :, 0:self.min_reps]

        if E1 > self.W:
            s = (E1 - self.W) // 2
            Gd = Gd[:, s:s + self.W, :]
            lv_rv = lv_rv[:, s:s + self.W]

        if E1 < self.W:
            s = (self.W - E1) // 2
            new_Gd = np.zeros((RO, self.W, self.min_reps))
            new_Gd[:, s:s + E1, :] = Gd
            Gd = new_Gd

            new_lv_rv = np.zeros((RO, self.W))
            new_lv_rv[:, s:s + E1] = lv_rv
            lv_rv = new_lv_rv

        return Gd, lv_rv

    def load_one_loc(self, loc, total_num_loaded, t0):
        """Load every case under ``img_dir/loc`` into the storage lists.

        Cases that fail to load or have an unexpected shape are reported and
        skipped. Returns ``total_num_loaded`` plus the number of cases accepted here.
        """
        t1 = time.time()

        a = os.listdir(os.path.join(self.img_dir, loc))
        num_samples = len(a)

        print('---> Start loading ', loc)
        tq = self._make_progress_bar(num_samples)

        try:
            for ii, n in enumerate(a):
                name = os.path.join(loc, n)
                tq.set_description('loading {}, total {}'.format(ii, num_samples))

                try:
                    Gd = self.load_one_data(name, 'aif_scc')
                    try:
                        lv_rv = self.load_one_data(name, 'aif_masks_final')
                    except Exception:
                        lv_rv = self.load_one_data(name, 'aif_masks')
                except Exception:
                    print('------> Failed to load %d out of %d, %s' % (ii, num_samples, n))
                    continue

                # a non-3-D or empty series cannot be padded/cropped; skip it instead of crashing
                if Gd.ndim != 3 or Gd.shape[2] == 0:
                    print('--> incorrect Gd shape : ', name)
                    continue

                Gd, lv_rv = self._fit_to_shape(Gd, lv_rv)

                RO, E1, N = Gd.shape
                if RO != self.RO or E1 != self.W or N != self.min_reps:
                    print('--> incorrect Gd shape : ', name)
                    continue

                # scale by the maximum; an all-zero series is left as is instead of becoming NaN
                max_val = np.max(Gd)
                if max_val != 0:
                    Gd = Gd / max_val
                Gd = np.transpose(Gd, (2, 0, 1))

                lv = np.zeros_like(lv_rv)
                lv[lv_rv == 1] = 1

                lv_rv = np.reshape(lv_rv, (1, lv_rv.shape[0], lv_rv.shape[1]))
                lv = np.reshape(lv, (1, lv.shape[0], lv.shape[1]))

                self.aif.append(Gd.astype(np.float32))
                self.lv_rv_masks.append(lv_rv.astype(np.float32))
                self.lv_masks.append(lv.astype(np.float32))
                self.names.append(name)

                total_num_loaded += 1

                t1 = time.time()
                tq.update(1)
                tq.set_postfix(loss='{:.2f}s'.format(t1 - t0))

            str_after_loading = '    Finish loading %s --- Total %d samples -- In %.2f seconds' % (loc, num_samples, t1 - t0)
            tq.set_postfix_str(str_after_loading)
        finally:
            tq.close()

        return total_num_loaded

    def __len__(self):
        return len(self.aif)

    def __getitem__(self, idx):
        """Return ``(image, mask, name)`` for sample ``idx`` (after ``transform``, if set).

        Raises:
            IndexError: ``idx`` is past the end of the dataset.
            ValueError: ``which_mask`` is neither 'lv' nor 'lv_rv'.
        """
        if idx >= len(self.aif):
            raise IndexError("invalid index")

        which = str(self.which_mask).lower()
        if which == 'lv_rv':
            masks = self.lv_rv_masks
        elif which == 'lv':
            masks = self.lv_masks
        else:
            raise ValueError("which_mask must be 'lv' or 'lv_rv', got %r" % (self.which_mask,))

        sample = (self.aif[idx], masks[idx], self.names[idx])

        if self.transform:
            sample = self.transform(sample)

        return sample

    def __str__(self):
        lines = [
            "Perfusion AIF Dataset",
            "  image root: %s" % self.img_dir,
            "  Number of samples: %d" % len(self.aif),
            "  Number of masks: %d" % len(self.lv_rv_masks),
        ]
        if len(self.aif) > 0:
            lines.append("  image shape: %d %d %d" % self.aif[0].shape)
            lines.append("  myo mask shape: %d %d %d" % self.lv_rv_masks[0].shape)
        return "\n".join(lines) + "\n"
# Build the dataset from every case found under img_dir.
perf_aif_dataset = PerfAIFDataset(img_dir=img_dir)
print("Done")
Found 3 cases ... 
Start loading cases ... 
---> Start loading  aif
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-344-9ed7552c92df> in <module>
----> 1 perf_aif_dataset = PerfAIFDataset(img_dir)
      2 print("Done")

<ipython-input-343-7d64878830c7> in __init__(self, img_dir, which_mask, min_reps, W, transform)
     36         for loc in locations:
     37             if(os.path.isdir(os.path.join(self.img_dir, loc))):
---> 38                 total_num_loaded = self.load_one_loc(loc, total_num_loaded, t0)
     39 
     40         print("Total samples loaded %d " % total_num_loaded)

<ipython-input-343-7d64878830c7> in load_one_loc(self, loc, total_num_loaded, t0)
     63 
     64         print('---> Start loading ', loc)
---> 65         tq = tqdm(total=(num_samples), file=sys.stdout)
     66 
     67         for ii, n in enumerate(a):

~\Anaconda3\envs\qperf\lib\site-packages\tqdm\notebook.py in __init__(self, *args, **kwargs)
    246         unit_scale = 1 if self.unit_scale is True else self.unit_scale or 1
    247         total = self.total * unit_scale if self.total else self.total
--> 248         self.container = self.status_printer(self.fp, total, self.desc, self.ncols)
    249         self.container.pbar = self
    250         if display_here:

~\Anaconda3\envs\qperf\lib\site-packages\tqdm\notebook.py in status_printer(_, total, desc, ncols)
    117                 "/user_install.html")
    118         if total:
--> 119             pbar = IProgress(min=0, max=total)
    120         else:  # No total? Show info style bar with no progress tqdm status
    121             pbar = IProgress(min=0, max=1)

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in __new__(cls, *args, **kwargs)
    956         else:
    957             inst = new_meth(cls, *args, **kwargs)
--> 958         inst.setup_instance(*args, **kwargs)
    959         return inst
    960 

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in setup_instance(self, *args, **kwargs)
    984         self._trait_notifiers = {}
    985         self._trait_validators = {}
--> 986         super(HasTraits, self).setup_instance(*args, **kwargs)
    987 
    988     def __init__(self, *args, **kwargs):

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in setup_instance(self, *args, **kwargs)
    975             else:
    976                 if isinstance(value, BaseDescriptor):
--> 977                     value.instance_init(self)
    978 
    979 

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in instance_init(self, obj)
   1690     def instance_init(self, obj):
   1691         self._resolve_classes()
-> 1692         super(Instance, self).instance_init(obj)
   1693 
   1694     def _resolve_classes(self):

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in instance_init(self, obj)
    518         # use provides a static default, transfer that to obj._trait_values.
    519         with obj.cross_validation_lock:
--> 520             if (self._dynamic_default_callable(obj) is None) \
    521                     and (self.default_value is not Undefined):
    522                 v = self._validate(obj, self.default_value)

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in _dynamic_default_callable(self, obj)
    502             mro = type(obj).mro()
    503             meth_name = '_%s_default' % self.name
--> 504             for cls in mro[:mro.index(self.this_class) + 1]:
    505                 if hasattr(cls, '_trait_default_generators'):
    506                     default_handler = cls._trait_default_generators.get(self.name)

ValueError: <class 'ipywidgets.widgets.widget.Widget'> is not in list
print(str(perf_aif_dataset))  # dataset summary from PerfAIFDataset.__str__
Perfusion AIF Dataset
  image root: ./data
  Number of samples: 3
  Number of masks: 3
  image shape: 64 64 48
  myo mask shape: 1 64 48

perf_aif_dataset.which_mask = 'lv_rv'


B = len(perf_aif_dataset)

# Every image must be (N=64, RO=64, E1=48) and every mask (*, RO=64, E1=48).
for n in range(B):
    Gd, masks, names = perf_aif_dataset[n]
    N, RO, E1 = Gd.shape
    if (N, RO, E1) != (64, 64, 48):
        print('--> incorrect Gd shape : ', (n, names, Gd.shape))
    _, RO, E1 = masks.shape
    if (RO, E1) != (64, 48):
        print('--> incorrect masks shape : ', (n, names, masks.shape))

Check that the data loaded properly

B = len(perf_aif_dataset)
print(B)
ni = np.zeros(B)
nm = np.zeros(B)

# Norm of each flattened image series and of its mask; cases whose image norm
# is below 1 are printed for inspection.
for n in np.arange(B):
    images, masks, names = perf_aif_dataset[n]
    ni[n], nm[n] = np.linalg.norm(images), np.linalg.norm(masks)
    if ni[n] < 1:
        print(names, ni[n])

#Original Plot from research paper
#plt.figure(figsize=(8,8))
#plt.plot(ni, nm, 'bo')
3
# Plotly
# x and y given as array_like objects
import plotly
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
from IPython.core.display import display, HTML

# Image-norm vs mask-norm scatter: written to an HTML file, then shown via IPython HTML.
fig = px.scatter(x=ni, y=nm)
plotly.offline.plot(fig, filename='figure_1.html')
display(HTML('figure_1.html'))
print(perf_aif_dataset)

# The last sample is held out for validation; all earlier samples are used for training.
NUM_TRAIN = len(perf_aif_dataset) - 1
batch_size = 2

train_sampler = sampler.SubsetRandomSampler(range(NUM_TRAIN))
val_sampler = sampler.SubsetRandomSampler(range(NUM_TRAIN, len(perf_aif_dataset)))

loader_for_train = DataLoader(perf_aif_dataset, batch_size=batch_size, sampler=train_sampler)
loader_for_val = DataLoader(perf_aif_dataset, batch_size=batch_size, sampler=val_sampler)
Perfusion AIF Dataset
  image root: ./data
  Number of samples: 3
  Number of masks: 3
  image shape: 64 64 48
  myo mask shape: 1 64 48

iter_train = iter(loader_for_train)
print(iter_train)
# Use the builtin next(): the `.next()` method alias is not available on the
# DataLoader iterators of newer PyTorch releases.
images, masks, names = next(iter_train)

print(images.shape)
print(masks.shape)

# (B, N, RO, E1) -> (RO, E1, N, B)
ia = np.transpose(images, (2, 3, 1, 0))
print(ia.shape)

#Original Plot from research paper
# for n in range(masks.shape[0]):
#     a = utils.cmr_ml_utils_plotting.plot_image_array(np.transpose(np.squeeze(images[n, 0:-1:4,:,:]), (1, 2, 0)), columns=16, figsize=[32,8])

# a = utils.cmr_ml_utils_plotting.plot_image_array(np.transpose(np.squeeze(masks), (1, 2, 0)), columns=1, figsize=[16,16])
<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x000001CBC9A405C8>
torch.Size([2, 64, 64, 48])
torch.Size([2, 1, 64, 48])
torch.Size([64, 48, 64, 2])

Images

#Plotly
import plotly.graph_objects as go
import numpy as np

# Every 4th time frame (0, 4, 8, ...) of the first sample in the batch;
# `images` holds the batch as (B, N, RO, E1).
frames = np.asarray(images[0, 0:-1:4, :, :])

# Create figure: one heatmap trace per frame. Only the first trace starts
# visible so the initial display agrees with the slider position (active=0);
# otherwise every frame is drawn stacked and the last one is what shows.
fig = go.Figure()
for step in range(frames.shape[0]):
    fig.add_trace(
        go.Heatmap(z=frames[step], colorscale="Gray", visible=(step == 0))
    )

# Create and add slider: step i shows trace i only
steps = []
for i in range(len(fig.data)):
    step = dict(
        method="update",
        args=[{"visible": [False] * len(fig.data)},
              {"title": "Slider switched to image: " + str(i)}],  # layout attribute
    )
    step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
    steps.append(step)

sliders = [dict(
    active=0,
    currentvalue={"prefix": "Image: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(
    sliders=sliders
)

fig.update_layout(
    width=500,
    height=600,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0)
)

fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Add dropdowns: colorscale choice and reversed-colorscale toggle
button_layer_1_height = 1.08
colorscale_buttons = [
    dict(args=["colorscale", cs], label=cs, method="restyle")
    for cs in ("Gray", "Cividis", "Blues", "Viridis")
]
reverse_buttons = [
    dict(args=["reversescale", flag], label=str(flag), method="restyle")
    for flag in (False, True)
]
fig.update_layout(
    updatemenus=[
        dict(
            buttons=colorscale_buttons,
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.3,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=reverse_buttons,
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.8,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
    ]
)

fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0.1, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.8, xref="paper", y=1.1,
                             yref="paper", showarrow=False)
    ])



plotly.offline.plot(fig, filename = 'figure_2.html')
display(HTML('figure_2.html'))
#Plotly
import plotly.graph_objects as go
import numpy as np

# Every 4th time frame (0, 4, 8, ...) of the second sample in the batch;
# `images` holds the batch as (B, N, RO, E1).
frames = np.asarray(images[1, 0:-1:4, :, :])

# Create figure: one heatmap trace per frame. Only the first trace starts
# visible so the initial display agrees with the slider position (active=0);
# otherwise every frame is drawn stacked and the last one is what shows.
fig = go.Figure()
for step in range(frames.shape[0]):
    fig.add_trace(
        go.Heatmap(z=frames[step], colorscale="Gray", visible=(step == 0))
    )

# Create and add slider: step i shows trace i only
steps = []
for i in range(len(fig.data)):
    step = dict(
        method="update",
        args=[{"visible": [False] * len(fig.data)},
              {"title": "Slider switched to image: " + str(i)}],  # layout attribute
    )
    step["args"][0]["visible"][i] = True  # Toggle i'th trace to "visible"
    steps.append(step)

sliders = [dict(
    active=0,
    currentvalue={"prefix": "Image: "},
    pad={"t": 50},
    steps=steps
)]

fig.update_layout(
    sliders=sliders
)

fig.update_layout(
    width=500,
    height=600,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0)
)

fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Add dropdowns: colorscale choice and reversed-colorscale toggle
button_layer_1_height = 1.08
colorscale_buttons = [
    dict(args=["colorscale", cs], label=cs, method="restyle")
    for cs in ("Gray", "Cividis", "Blues", "Viridis")
]
reverse_buttons = [
    dict(args=["reversescale", flag], label=str(flag), method="restyle")
    for flag in (False, True)
]
fig.update_layout(
    updatemenus=[
        dict(
            buttons=colorscale_buttons,
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.3,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=reverse_buttons,
            direction="down",
            pad={"r": 10, "t": -5},
            showactive=True,
            x=0.8,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
    ]
)

fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0.1, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.8, xref="paper", y=1.1,
                             yref="paper", showarrow=False)
    ])



plotly.offline.plot(fig, filename = 'figure_3.html')
display(HTML('figure_3.html'))

Masks

# Plotly
import plotly.graph_objects as go

import pandas as pd

# Create figure: one heatmap per batch sample; `masks` holds the batch as (B, 1, RO, E1).
# Trace 0 ('Figure 1 - Mask') is sample 1 and trace 1 ('Figure 2 - Mask') is sample 0.
fig = go.Figure()

# Only the trace matching the dropdown's active entry (index 1, 'Figure 2 - Mask')
# starts visible, so the two opaque heatmaps are not stacked on top of each other.
fig.add_trace(
     go.Heatmap(z=np.asarray(masks[1, 0]), name='Figure 1 - Mask', colorscale="Gray", visible=False)
)

fig.add_trace(
     go.Heatmap(z=np.asarray(masks[0, 0]), name='Figure 2 - Mask', colorscale="Gray", visible=True)
)

# Update plot sizing
fig.update_layout(
    width=800,
    height=900,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
)

# Update 3D scene options
fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Add dropdowns: colorscale choice, reversed-colorscale toggle, and mask selector
button_layer_1_height = 1.08
colorscale_buttons = [
    dict(args=["colorscale", cs], label=cs, method="restyle")
    for cs in ("Gray", "Cividis", "Blues", "Viridis")
]
reverse_buttons = [
    dict(args=["reversescale", flag], label=str(flag), method="restyle")
    for flag in (False, True)
]
fig.update_layout(
    updatemenus=[
        dict(
            buttons=colorscale_buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=reverse_buttons,
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.37,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            active=1,
            buttons=[
                dict(label='Figure 1 - Mask',
                     method='update',
                     args=[{'visible': [True, False]},
                           {'title': 'Figure 1 - Mask',
                            'showlegend': True}]),
                dict(label='Figure 2 - Mask',
                     method='update',
                     args=[{'visible': [False, True]},
                           {'title': 'Figure 2 - Mask',
                            'showlegend': True}]),
            ]
        )
    ]
)

fig.update_layout(
    annotations=[
        dict(text="colorscale", x=-0.02, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=1.07,
                             yref="paper", showarrow=False)
    ])


plotly.offline.plot(fig, filename = 'figure_4.html')
display(HTML('figure_4.html'))

Construct data loaders

USE_GPU = True

# CUDA only when it is both requested and present; CPU otherwise.
device = torch.device('cuda') if (USE_GPU and torch.cuda.is_available()) else torch.device('cpu')

print('using device:', device)

gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
using device: cpu

Data augmentation

class RandomFlip1stDim(object):
    """Randomly flip the first spatial dimension (array axis 1, RO) of a sample.

    The sample is an ``(image, mask, name)`` tuple whose arrays are laid out as
    [N RO E1 ...]. With probability ``p``, image and mask are flipped together
    along axis 1 so they stay aligned; the name is passed through unchanged.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (tuple): (image [N RO E1 ...], mask [C RO E1 ...], name).
        Returns:
            res: the sample, flipped along axis 1 with probability p. Flipped
            arrays are returned as contiguous copies (presumably so torch can
            wrap them; negative-stride views are not accepted by torch.from_numpy).
        """
        if random.random() < self.p:
            image = np.flip(img[0], axis=1).copy()
            mask = np.flip(img[1], axis=1).copy()
            return (image, mask, img[2])
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
    
class RandomFlip2ndDim(object):
    """Randomly flip the second spatial dimension (array axis 2, E1) of a sample.

    The sample is an ``(image, mask, name)`` tuple whose arrays are laid out as
    [N RO E1 ...]. With probability ``p``, image and mask are flipped together
    along axis 2 so they stay aligned; the name is passed through unchanged.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """
        Args:
            img (tuple): (image [N RO E1 ...], mask [C RO E1 ...], name).
        Returns:
            res: the sample, flipped along axis 2 with probability p. Flipped
            arrays are returned as contiguous copies (presumably so torch can
            wrap them; negative-stride views are not accepted by torch.from_numpy).
        """
        if random.random() < self.p:
            image = np.flip(img[0], axis=2).copy()
            mask = np.flip(img[1], axis=2).copy()
            return (image, mask, img[2])
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import scipy
from scipy import ndimage

# probs should be a 64x48 torch tensor
def adaptive_thresh_cpu(probs, p_thresh=0.5, p_thresh_max=0.988):
    """Threshold a 2-D probability map into a single-blob mask.

    The map is thresholded at ``max(probs) * p_thresh``; while that leaves more
    than one connected blob, the threshold is raised by 0.01 (the labelling is
    always done at least once). If several blobs survive once the threshold
    reaches ``p_thresh_max``, the largest blob is kept.

    Args:
        probs: 2-D CPU torch tensor (RO, E1) of probabilities.
        p_thresh: starting fraction of the maximum used as threshold.
        p_thresh_max: threshold fraction at which the search stops; keep it
            below 1 so the loop terminates before the maximum itself is cut.

    Returns:
        float32 torch tensor of shape (RO, E1), 1 inside the selected blob and 0
        elsewhere; all zeros when no pixel passes the threshold or an error occurs.
    """
    p_thresh_incr = 0.01

    RO = probs.shape[0]
    E1 = probs.shape[1]

    try:
        while True:
            mask = (probs > torch.max(probs) * p_thresh).float()
            blobs, number_of_blobs = ndimage.label(mask)
            p_thresh += p_thresh_incr  # repeated float addition drifts slightly; harmless here
            if number_of_blobs <= 1 or p_thresh >= p_thresh_max:
                break

        if number_of_blobs == 1:
            return mask

        if number_of_blobs == 0:
            print("adaptive_thresh_cpu, did not find any blobs ... ", file=sys.stderr)
            sys.stderr.flush()
            return torch.zeros(RO, E1)

        ## If we are here then we cannot isolate a singular blob as the LV.
        ## Select the largest blob as the final mask (first one on ties).
        sizes = np.bincount(blobs.ravel())[1:]  # sizes[k] = pixel count of label k + 1
        biggest_label = int(np.argmax(sizes)) + 1
        return torch.from_numpy((blobs == biggest_label).astype(np.float32))

    except Exception as e:
        print("Error happened in adaptive_thresh_cpu ...", file=sys.stderr)
        print(e, file=sys.stderr)
        sys.stderr.flush()

    return torch.zeros(RO, E1)
def compute_dice_scores(best_model, loader, aif_trainer, binary_seg=False, device=None):
    """Compute the LV Dice score for every sample served by ``loader``.

    Args:
        best_model: segmentation network; moved to ``device`` and put in eval mode.
        loader: iterable of ``(x, y, names)`` batches.
        aif_trainer: object providing ``x_dtype`` / ``y_dtype`` used to cast each batch.
        binary_seg: if True, the network output goes through a sigmoid and channel 0
            is the LV probability; otherwise softmax over dim 1 and channel 1 is used.
        device: torch device for inference; defaults to CUDA when available, else CPU.

    Returns:
        (dice_scores, cases): per-sample Dice values and the matching case names.
        Samples with Dice < 0.1 are printed and their probabilities/masks plotted.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dice_scores = []
    cases = []

    best_model = best_model.to(device)
    best_model.eval()  # set model to evaluation mode

    lv_channel = 0 if binary_seg else 1  # output channel holding the LV probability

    for x, y, names in loader:
        x = x.to(aif_trainer.x_dtype).to(device)

        with torch.no_grad():
            scores = best_model(x)

        if binary_seg:
            probs = torch.sigmoid(scores)
        else:
            probs = torch.softmax(scores, dim=1)
        probs = probs.cpu().detach()

        # reference labels are only compared in numpy, so they stay on the CPU
        aif_mask = y.to(aif_trainer.y_dtype).cpu().numpy()

        for n in range(x.shape[0]):
            lv_probs = probs[n, lv_channel, :, :]
            # np.asarray accepts both a CPU torch tensor and a numpy array,
            # whichever adaptive_thresh_cpu returns
            lv_mask = np.asarray(adaptive_thresh_cpu(lv_probs, p_thresh=0.5))

            lv_aif_mask = np.zeros(lv_mask.shape)
            lv_aif_mask[np.where(np.squeeze(aif_mask[n, 0, :, :] == 1))] = 1

            ds = training.dice(lv_aif_mask, lv_mask)

            if ds < 0.1:
                print(names[n])
                # (C, RO, E1) -> (RO, E1, C), the same layout as the (RO, E1) masks below
                curr_probs = np.transpose(probs[n, :, :, :].numpy(), (1, 2, 0))
                a = utils.cmr_ml_utils_plotting.plot_image_array(np.squeeze(curr_probs), columns=8, figsize=[16,16])
                a = utils.cmr_ml_utils_plotting.plot_image_array(lv_aif_mask, columns=8, figsize=[16,16])
                a = utils.cmr_ml_utils_plotting.plot_image_array(lv_mask, columns=8, figsize=[16,16])

            dice_scores.append(ds)
            cases.append(names[n])

    return dice_scores, cases
def get_failed_cases(dice_scores, cases, thres=0.5, print_failed=True):
    """Split cases by Dice score and report the success rate.

    A case succeeds when its Dice score is >= ``thres``; anything else
    (including NaN) counts as a failure.

    Args:
        dice_scores: per-case Dice values.
        cases: case names, aligned with ``dice_scores``.
        thres: minimum Dice for success.
        print_failed: print each failed case and its Dice score.

    Returns:
        (failed_cases, failed_dices, success_rate). ``success_rate`` is 0.0
        when ``dice_scores`` is empty instead of raising ZeroDivisionError.
    """
    total_samples = len(dice_scores)
    success_samples = 0
    failed_cases = []
    failed_dices = []
    for dice, case in zip(dice_scores, cases):
        if dice >= thres:
            success_samples += 1
            continue
        if print_failed:
            print("case %s, dice %f " % (case, dice))
        failed_cases.append(case)
        failed_dices.append(dice)

    success_rate = success_samples / total_samples if total_samples else 0.0
    print("Total test samples is ", total_samples)
    print("Success rate is ", success_rate)

    return failed_cases, failed_dices, success_rate
def load_apply_model_multi_class(img_dir, case_name, model, device):
    """Run the multi-class AIF model on one stored case and plot the result.

    Loads ``aif.npy`` and the AIF mask from ``img_dir/case_name``. The series is
    cropped to at most the first 64 frames and a centred 48-pixel window along
    its second axis, then normalised by its maximum. It is fed to ``model`` as a
    batch of one. The class probabilities, the ground-truth mask and the
    thresholded class-1 (LV) mask are plotted and returned.

    Args:
        img_dir: root folder containing one sub-folder per case.
        case_name: name of the case sub-folder.
        model: trained segmentation network, called on a (1, frames, RO, 48)
            float tensor.
        device: torch device the model lives on.

    Returns:
        (probs, lv_mask, aif_mask):
        - ``probs``: the softmax output as a squeezed numpy array with the
          class axis last. This assumes the model returns
          (1, num_classes, H, W); TODO confirm.
        - ``lv_mask``: the thresholded class-1 probability map.
        - ``aif_mask``: the loaded mask array.

    Raises:
        OSError / ValueError from ``np.load`` if ``aif.npy`` cannot be read, or
        if neither mask file can be read.
    """
    import os  # the notebook header leaves `import os` commented out

    data_dir = os.path.join(img_dir, case_name)
    print(data_dir)

    Gd = np.load(os.path.join(data_dir, 'aif.npy'))
    RO, E1, N = Gd.shape

    # Prefer the final masks. Fall back to the original ones only when the
    # final file is missing or cannot be parsed (I/O or format errors).
    try:
        aif_mask = np.load(os.path.join(data_dir, 'aif_masks_final.npy'))
    except (OSError, ValueError):
        aif_mask = np.load(os.path.join(data_dir, 'aif_masks.npy'))

    # Keep at most the first 64 frames and a centred 48-pixel window along E1.
    Gd = Gd[:, :, 0:64]
    s = int((E1 - 48) / 2)
    Gd = Gd[:, s:s + 48, :]

    # (RO, 48, frames) -> (1, frames, RO, 48): frames become the input channels.
    Gd = np.transpose(Gd, (2, 0, 1))
    Gd = np.reshape(Gd, (1, Gd.shape[0], Gd.shape[1], Gd.shape[2]))
    # In-place true division raises on integer arrays, so normalise in floating point.
    if not np.issubdtype(Gd.dtype, np.floating):
        Gd = Gd.astype(np.float32)
    Gd /= np.max(Gd)

    aif = torch.from_numpy(Gd).float().to(device=device)
    model.eval()
    with torch.no_grad():
        scores = model(aif)
        probs = torch.nn.Softmax(dim=1)(scores)

    probs = probs.cpu()

    # Channel 1 of the softmax output is treated as the LV probability map.
    lv_probs = probs[0, 1, :, :]
    lv_mask = training.adaptive_thresh(lv_probs, device=torch.device('cpu'), p_thresh=0.5)
    lv_mask = lv_mask.cpu().detach().numpy()

    # (1, C, H, W) -> (H, W, C, 1), squeezed, so each class is plotted as one image.
    probs = np.squeeze(np.transpose(probs.numpy(), (2, 3, 1, 0)))

    utils.cmr_ml_utils_plotting.plot_image_array(probs, columns=8, figsize=[16,16])
    utils.cmr_ml_utils_plotting.plot_image_array(aif_mask, columns=8, figsize=[16,16])
    utils.cmr_ml_utils_plotting.plot_image_array(lv_mask, columns=8, figsize=[16,16])

    return probs, lv_mask, aif_mask

Train with multi-class trainer

# Data augmentation: random flips along the first and second image dimensions.
# `T` is the torchvision.transforms alias imported at the top of the notebook.
# The `import torchvision.<sub> as <alias>` statements there do not bind the
# bare name `torchvision`.
transform = T.Compose([RandomFlip1stDim(0.5), RandomFlip2ndDim(0.5)])
perf_aif_dataset.transform = transform

# Multi-class LV/RV setup:
# - 3 output classes, with accuracy tracked for classes 1 and 2;
# - class 1 weighted 5x in the loss built later from `class_weights`;
# - one probability threshold per class.
perf_aif_dataset.which_mask = 'lv_rv'
num_classes = 3
class_for_accu = [1, 2]
class_weights = np.ones(num_classes)
class_weights[1] = 5
p_thres = [0.5, 0.5, 0.75]
# NOTE(review): the recorded output of this cell prints 'lv' and a mask shape of
# (1, 64, 48), which does not match which_mask = 'lv_rv'. The output may be
# stale; re-run to confirm.
print(perf_aif_dataset.which_mask)

# Fetch one sample. sample[1] is presumably its mask (it is what gets plotted).
sample = perf_aif_dataset[1]

print(sample[1].shape)

#Original Plot from research paper
# plt.figure()
# plt.imshow(np.squeeze(sample[1]))
lv
(1, 64, 48)
# Interactive Plotly version of the mask plot above (sample[1], squeezed).
# `go`, `plotly`, `display` and `HTML` must already be in scope. `go` and
# `plotly` are presumably imported elsewhere in the notebook; TODO confirm.
#fig = px.imshow(np.squeeze(sample[1]), binary_string=True)

# Commented-out line, apparently left over from a Plotly example; not used.
#df = pd.read_csv("https://raw.githubusercontent.com/plotly/datasets/master/volcano.csv")

# Create an empty figure.
fig = go.Figure()

# Add the squeezed mask as a 2-D heatmap with a gray colorscale.
fig.add_trace(go.Heatmap(z=np.squeeze(sample[1]), colorscale="Gray"))



# Fixed 800x900 px size; the 100 px top margin leaves room for the dropdowns.
fig.update_layout(
    width=800,
    height=900,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
)

# 3-D scene options. They only affect 3-D traces, so they have no visible
# effect on this 2-D heatmap.
fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Two dropdown menus above the plot area (y > 1 in paper coordinates):
# - the first switches the heatmap colorscale;
# - the second toggles `reversescale`.
# "restyle" updates the named trace attribute in place.
button_layer_1_height = 1.08
fig.update_layout(
    updatemenus=[
        dict(
            buttons=list([
                dict(
                    args=["colorscale", "Gray"],
                    label="Gray",
                    method="restyle"
                ),
                dict(
                    args=["colorscale", "Cividis"],
                    label="Cividis",
                    method="restyle"
                ),
                dict(
                    args=["colorscale", "Blues"],
                    label="Blues",
                    method="restyle"
                ),
                dict(
                    args=["colorscale", "Viridis"],
                    label="Viridis",
                    method="restyle"
                ),
            ]),
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=list([
                dict(
                    args=["reversescale", False],
                    label="False",
                    method="restyle"
                ),
                dict(
                    args=["reversescale", True],
                    label="True",
                    method="restyle"
                )
            ]),
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.37,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
    ]
)

# Text labels placed next to the two dropdown menus.
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=1.07,
                             yref="paper", showarrow=False)
    ])


# Write the figure to a standalone HTML file, then render that file inline.
# IPython's HTML() loads the file when given an existing path.
# NOTE(review): plotly.offline.plot opens the file in a browser by default
# (auto_open=True); pass auto_open=False if that is not wanted.
plotly.offline.plot(fig, filename = 'figure_5.html')
display(HTML('figure_5.html'))
# Pull one batch from the training iterator to check shapes and value ranges.
# Use the builtin next(): DataLoader iterators in recent PyTorch releases no
# longer provide a Python-2 style .next() method.
images, masks, names = next(iter_train)

B, C, RO, E1 = images.shape

print(images.shape)
print(masks.shape)
print(torch.max(images))
print(torch.max(masks))

# Channel 32 of every image in the batch (presumably one time frame), reshaped to
# (B, 1, RO, E1) so make_grid tiles it as single-channel images.
a = images[:,32,:,:]
a = torch.reshape(a, (B, 1, RO, E1))

# `show` is a notebook helper defined elsewhere.
plt.figure(figsize=(8, 8))
show(make_grid(a.double(), nrow=8, padding=2, normalize=False, scale_each=True))

# The masks of the same batch.
plt.figure(figsize=(8, 8))
show(make_grid(masks.double(), nrow=8, padding=2, normalize=True, scale_each=False))

X = images.type(torch.FloatTensor)
y = masks.type(torch.FloatTensor)
print(X.shape)
print(y.shape)
def perform_training(hyperpara, perf_aif_dataset, loader_for_train, loader_for_val):
    """Train one LV/RV multi-class ResUnet configuration and evaluate it.

    Builds a ``models.GadgetronResUnet18`` from ``hyperpara`` and trains it with
    ``training.GadgetronMultiClassSeg_Perf``. It then computes per-case Dice
    scores of the best model on both loaders. The validation results are saved
    to ``perf_aif_lv_rv_val_failed.mat`` in the notebook-level ``img_dir``.

    Args:
        hyperpara: dict with the following keys:
            - ``num_epochs``, ``inplanes``, ``layers``, ``layers_planes``;
            - ``class_weights``: scalar loss weight for class 1;
            - ``jaccard_weight``.
            The dict is updated in place with the results.
        perf_aif_dataset: dataset whose ``aif[0]`` gives the (C, H, W) input
            shape. Its ``which_mask`` is set to ``'lv_rv'``.
        loader_for_train: DataLoader used for training.
        loader_for_val: DataLoader used for validation.

    Returns:
        The same ``hyperpara`` dict, extended with the following:
        - the best model and the training history;
        - validation Dice statistics (``dice_scores``, ``cases``,
          ``failed_cases``, ``failed_dices``, ``success_rate``);
        - training-set Dice statistics (``*_train`` keys).

    Relies on the notebook globals ``models``, ``training``, ``device``,
    ``img_dir`` and ``compute_dice_scores``.
    """
    # `os` is imported locally because the notebook header leaves `import os`
    # commented out. `scipy.io` must be imported explicitly for savemat.
    import os
    import scipy.io

    num_epochs = hyperpara['num_epochs']
    print_every = 100000

    inplanes = hyperpara['inplanes']
    layers = hyperpara['layers']
    layers_planes = hyperpara['layers_planes']

    class_weights = hyperpara['class_weights']
    jaccard_weight = hyperpara['jaccard_weight']

    print('======================================================')
    print('num_epochs ', num_epochs)
    print('inplanes ', inplanes)
    print('layers ', layers)
    print('layers_planes ', layers_planes)
    print('class_weights ', class_weights)
    print('jaccard_weight ', jaccard_weight)
    print('======================================================')

    # Three-class LV/RV segmentation. Accuracy is tracked for classes 1 and 2,
    # with one probability threshold per class.
    perf_aif_dataset.which_mask = 'lv_rv'
    num_classes = 3
    class_for_accu = [1, 2]
    p_thres = [0.5, 0.5, 0.75]

    print(perf_aif_dataset.aif[0].shape)
    C, H, W = perf_aif_dataset.aif[0].shape

    model = models.GadgetronResUnet18(F0=C,
                                      inplanes=inplanes,
                                      layers=layers,
                                      layers_planes=layers_planes,
                                      use_dropout=False,
                                      p=0.5,
                                      H=H, W=W, C=num_classes,
                                      verbose=True)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        print("model on multiple GPU ... ")

    patience = 10
    # NOTE(review): `factor` is not passed to ReduceLROnPlateau below, so the
    # scheduler uses its default reduction factor. Passing factor=factor would
    # change training results; confirm which is intended.
    factor = 0.5
    cooldown = 3
    min_lr = 1e-7

    weight_decay = 0
    learning_rate = 1e-3

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999),
                           eps=1e-08, weight_decay=weight_decay, amsgrad=False)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)

    # Loss weights: 1 for every class except class 1, which gets hyperpara['class_weights'].
    CW = np.ones(num_classes)
    CW[1] = class_weights
    criterion = training.LossMulti(class_weights=CW, jaccard_weight=jaccard_weight)

    # NOTE(review): log_dir and model_folder depend only on lr/epochs, so every
    # configuration of a hyper-parameter sweep shares them.
    log_dir = 'aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
    writer = SummaryWriter(log_dir)

    aif_trainer = training.GadgetronMultiClassSeg_Perf(model,
                                                       optimizer,
                                                       criterion,
                                                       loader_for_train,
                                                       loader_for_val,
                                                       class_for_accu=class_for_accu,
                                                       p_thres=p_thres,
                                                       scheduler=scheduler,
                                                       epochs=num_epochs,
                                                       device=device,
                                                       x_dtype=torch.float32,
                                                       y_dtype=torch.long,
                                                       early_stopping_thres=100,
                                                       print_every=print_every,
                                                       writer=writer,
                                                       model_folder="perf_training/")

    epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(
        verbose=True, epoch_to_load=-1, save_model_epoch=True)

    # Validation Dice of the best model; the results are saved for later review.
    dice_scores, cases = compute_dice_scores(best_model, loader_for_val, aif_trainer)
    failed_cases, failed_dices, success_rate = get_failed_cases(
        dice_scores, cases, thres=0.5, print_failed=False)
    # "cases"/"dice_scores" hold all validation cases; "failed_cases"/"dices"
    # hold only the cases below the Dice threshold.
    scipy.io.savemat(os.path.join(img_dir, 'perf_aif_lv_rv_val_failed.mat'),
                     {"cases": cases, "failed_cases": failed_cases,
                      "dices": failed_dices, "dice_scores": dice_scores})

    # Dice of the best model on the training set.
    dice_scores_train, cases_train = compute_dice_scores(best_model, loader_for_train, aif_trainer)
    failed_cases_train, failed_dices_train, success_rate_train = get_failed_cases(
        dice_scores_train, cases_train, thres=0.5, print_failed=False)

    # Flush and release the TensorBoard event file for this configuration.
    writer.close()

    hyperpara['best_model'] = best_model
    hyperpara['epochs_traning'] = epochs_traning
    hyperpara['epochs_validation'] = epochs_validation
    hyperpara['loss_all'] = loss_all
    hyperpara['epochs_acc_class'] = epochs_acc_class

    hyperpara['dice_scores'] = dice_scores
    hyperpara['cases'] = cases
    hyperpara['failed_cases'] = failed_cases
    hyperpara['failed_dices'] = failed_dices
    hyperpara['success_rate'] = success_rate

    hyperpara['dice_scores_train'] = dice_scores_train
    hyperpara['cases_train'] = cases_train
    hyperpara['failed_cases_train'] = failed_cases_train
    hyperpara['failed_dices_train'] = failed_dices_train
    hyperpara['success_rate_train'] = success_rate_train

    return hyperpara
# Hyper-parameter grid search over the ResUnet architecture.
import itertools

layers_planes = [[96, 128], [128, 128], [128, 256]]
layers = [[2, 3], [3, 3], [3, 4], [4, 4], [4, 5]]
inplanes = [64, 96, 128]

best_success_rate = 0
best_hyperpara = None

hyperpara_all = []

# Split the dataset once into k random folds: fold 0 is the validation set and
# the remaining folds form the training set. The split is drawn before the
# search, so every configuration is trained and scored on the same data and
# the success rates are comparable.
k = 12
batch_size = 256

shuffled_idxs = np.random.permutation(len(perf_aif_dataset))
listified_chunks = [fold.tolist() for fold in np.array_split(shuffled_idxs, k)]

val_idxs = listified_chunks[0]
train_idxs = [idx for fold in listified_chunks[1:] for idx in fold]

loader_for_train = DataLoader(perf_aif_dataset, batch_size=batch_size,
                              sampler=sampler.SubsetRandomSampler(train_idxs))

loader_for_val = DataLoader(perf_aif_dataset, batch_size=batch_size,
                            sampler=sampler.SubsetRandomSampler(val_idxs))

num_train = len(train_idxs)
print('num_train = %d' % num_train)
num_val = len(val_idxs)
print('num_val = %d' % num_val)

# Same visiting order as three nested loops (layers_planes outermost).
for (a, planes), (b, n_layers), (c, n_inplanes) in itertools.product(
        enumerate(layers_planes), enumerate(layers), enumerate(inplanes)):

    print('-----------------------------------------------')
    print(a, b, c)

    hyperpara = {
        'num_epochs': 40,
        'inplanes': n_inplanes,
        'layers': n_layers,
        'layers_planes': planes,
        'class_weights': 5.0,
        'jaccard_weight': 0.5,
    }

    hyperpara = perform_training(hyperpara, perf_aif_dataset, loader_for_train, loader_for_val)

    print('success rate - ', hyperpara['success_rate'])

    if hyperpara['success_rate'] > best_success_rate:
        best_success_rate = hyperpara['success_rate']
        best_hyperpara = hyperpara

    hyperpara_all.append(hyperpara)
-----------------------------------------------
0 0 0
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-367-c7af734a631f> in <module>
     30 
     31             # Chunk into k random sets
---> 32             chunks = chunk(range(len(perf_aif_dataset)), k)
     33             listified_chunks = list(chunks)
     34 

NameError: name 'chunk' is not defined
# Build the multi-class ResUnet, optimizer, LR scheduler, loss and TensorBoard
# writer for a single training run. num_classes and class_weights come from the
# augmentation cell earlier in the notebook.
num_epochs = 50
print_every = 100000

# Three candidate architectures follow. Each overwrites the previous one, so
# only the last ("resnet, small": inplanes=96, layers=[2, 3],
# layers_planes=[128, 128], growth_rate=8) takes effect. growth_rate is not
# passed to GadgetronResUnet18 in this cell.
# resnet
inplanes = 96
layers=[4, 4]
layers_planes=[128, 128]
growth_rate = 8

#dense net
inplanes = 64
layers=[3, 3]
layers_planes=[16, 32]
growth_rate = 16

# resnet, small
inplanes = 96
layers=[2, 3]
layers_planes=[128, 128]
growth_rate = 8

# Input shape (C, H, W) of one AIF series; C is passed as the network's F0.
print(perf_aif_dataset.aif[0].shape)
C, H, W = perf_aif_dataset.aif[0].shape

model = models.GadgetronResUnet18(F0=C, 
                          inplanes=inplanes, 
                          layers=layers, 
                          layers_planes=layers_planes, 
                          use_dropout=False, 
                          p=0.5, 
                          H=H, W=W, C=num_classes,
                          verbose=True)

# Wrap the model in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print("model on multiple GPU ... ")

# ReduceLROnPlateau settings.
# NOTE(review): `factor` is not passed to the scheduler below, so the scheduler
# uses ReduceLROnPlateau's default factor. Pass factor=factor if 0.5 was intended.
patience = 10
factor = 0.5
cooldown = 3
min_lr = 1e-7

weight_decay=0
learning_rate = 1e-3

optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)

# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)

# Multi-class loss with per-class weights (class_weights from an earlier cell).
criterion = training.LossMulti(class_weights=class_weights, jaccard_weight=0.5)
# criterion = nn.BCEWithLogitsLoss()
# criterion = nn.BCELoss()

# NOTE(review): the binary run later in the notebook builds
# 'aif_training/ResUnet_lr_..._epochs_...' with the same lr and epochs, which is
# the same directory as this one. The two runs' event files therefore end up
# mixed; consider a run-specific suffix.
log_dir = './aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
writer = SummaryWriter(log_dir)
(64, 64, 48)
GadgetronResUnet : F0=64, inplanes=96
------------------------------------------------------------
    GadgetronResUnetInputBlock : input size (64, 64, 48), output size (96, 64, 48) --> (96, 64, 48)
------------------------------------------------------------
    GadgetronResUnet, down layer 0:
        GadgetronResUnet, down layer (64, 48) -> (32, 24)
        GadgetronResUnetBasicBlock : input size (96, 32, 24), output size (128, 32, 24) --> (128, 32, 24)
        GadgetronResUnetBasicBlock : input size (128, 32, 24), output size (128, 32, 24) --> (128, 32, 24)
    GadgetronResUnet, down layer 1:
        GadgetronResUnet, down layer (32, 24) -> (16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
------------------------------------------------------------
    GadgetronResUnet, bridge layer (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnet, down layer (16, 12) -> (8, 6)
        GadgetronResUnetBasicBlock : input size (128, 8, 6), output size (128, 8, 6) --> (128, 8, 6)
        GadgetronResUnetBasicBlock : input size (128, 8, 6), output size (128, 8, 6) --> (128, 8, 6)
        GadgetronResUnetBasicBlock : input size (128, 8, 6), output size (128, 8, 6) --> (128, 8, 6)
------------------------------------------------------------
    GadgetronResUnet, up layer 0:
        GadgetronResUnet_UpSample : input size (256, 8, 6), upsampled size (256, 16, 12)
        GadgetronResUnetBasicBlock : input size (256, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
    GadgetronResUnet, up layer 1:
        GadgetronResUnet_UpSample : input size (256, 16, 12), upsampled size (256, 32, 24)
        GadgetronResUnetBasicBlock : input size (256, 32, 24), output size (96, 32, 24) --> (96, 32, 24)
        GadgetronResUnetBasicBlock : input size (96, 32, 24), output size (96, 32, 24) --> (96, 32, 24)
    GadgetronResUnet, up layer 2:
        GadgetronResUnet_UpSample : input size (192, 32, 24), upsampled size (192, 64, 48)
        GadgetronResUnetBasicBlock : input size (192, 64, 48), output size (96, 64, 48) --> (96, 64, 48)
        GadgetronResUnetBasicBlock : input size (96, 64, 48), output size (96, 64, 48) --> (96, 64, 48)
------------------------------------------------------------
Output layer (96, 64, 48) --> (3, 64, 48)
------------------------------------------------------------
# Multi-class trainer: trains `model` on loader_for_train and validates on
# loader_for_val. class_for_accu and p_thres (class indices and per-class
# probability thresholds) come from an earlier cell. Models are written under
# "aif_training/".
# NOTE(review): early_stopping_thres=100 exceeds epochs=50. If it is an
# epoch-count patience, early stopping can never trigger; confirm.
aif_trainer = training.GadgetronMultiClassSeg_Perf(model, 
                                   optimizer, 
                                   criterion, 
                                   loader_for_train, 
                                   loader_for_val, 
                                   class_for_accu=class_for_accu,
                                   p_thres = p_thres,
                                   scheduler=scheduler, 
                                   epochs=num_epochs, 
                                   device=device, 
                                   x_dtype=torch.float32, 
                                   y_dtype=torch.long, 
                                   early_stopping_thres = 100,                              
                                   print_every=print_every,
                                   small_data_mode = False, 
                                   writer=writer, 
                                   model_folder="aif_training/")
# train() returns five values, unpacked here. Their exact formats are defined
# by the training package.
# NOTE(review): the recorded run failed inside train() while creating its
# tqdm_notebook progress bar (a ValueError raised from ipywidgets/traitlets).
# That points to an ipywidgets/tqdm environment mismatch rather than to this cell.
epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)
50
Start training ... 
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
----------------------------------------
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-373-ebf2cf2d36fb> in <module>
----> 1 epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)

~\Desktop\New folder (2) temp\deep_learning\aif_detection\training\training_base.py in train(self, verbose, epoch_to_load, save_model_epoch)
    113 
    114             random.seed()
--> 115             tq = tqdm_notebook(total=(len(self.loader_for_train) * batch_size), file=sys.stdout)
    116             tq.set_description('Epoch {}, total {}'.format(e, self.epochs))
    117 

~\Anaconda3\envs\qperf\lib\site-packages\tqdm\__init__.py in tqdm_notebook(*args, **kwargs)
     25          "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`",
     26          TqdmDeprecationWarning, stacklevel=2)
---> 27     return _tqdm_notebook(*args, **kwargs)
     28 
     29 

~\Anaconda3\envs\qperf\lib\site-packages\tqdm\notebook.py in __init__(self, *args, **kwargs)
    246         unit_scale = 1 if self.unit_scale is True else self.unit_scale or 1
    247         total = self.total * unit_scale if self.total else self.total
--> 248         self.container = self.status_printer(self.fp, total, self.desc, self.ncols)
    249         self.container.pbar = self
    250         if display_here:

~\Anaconda3\envs\qperf\lib\site-packages\tqdm\notebook.py in status_printer(_, total, desc, ncols)
    117                 "/user_install.html")
    118         if total:
--> 119             pbar = IProgress(min=0, max=total)
    120         else:  # No total? Show info style bar with no progress tqdm status
    121             pbar = IProgress(min=0, max=1)

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in __new__(cls, *args, **kwargs)
    956         else:
    957             inst = new_meth(cls, *args, **kwargs)
--> 958         inst.setup_instance(*args, **kwargs)
    959         return inst
    960 

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in setup_instance(self, *args, **kwargs)
    984         self._trait_notifiers = {}
    985         self._trait_validators = {}
--> 986         super(HasTraits, self).setup_instance(*args, **kwargs)
    987 
    988     def __init__(self, *args, **kwargs):

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in setup_instance(self, *args, **kwargs)
    975             else:
    976                 if isinstance(value, BaseDescriptor):
--> 977                     value.instance_init(self)
    978 
    979 

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in instance_init(self, obj)
   1690     def instance_init(self, obj):
   1691         self._resolve_classes()
-> 1692         super(Instance, self).instance_init(obj)
   1693 
   1694     def _resolve_classes(self):

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in instance_init(self, obj)
    518         # use provides a static default, transfer that to obj._trait_values.
    519         with obj.cross_validation_lock:
--> 520             if (self._dynamic_default_callable(obj) is None) \
    521                     and (self.default_value is not Undefined):
    522                 v = self._validate(obj, self.default_value)

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in _dynamic_default_callable(self, obj)
    502             mro = type(obj).mro()
    503             meth_name = '_%s_default' % self.name
--> 504             for cls in mro[:mro.index(self.this_class) + 1]:
    505                 if hasattr(cls, '_trait_default_generators'):
    506                     default_handler = cls._trait_default_generators.get(self.name)

ValueError: <class 'ipywidgets.widgets.widget.Widget'> is not in list

Loss

#Original Plot from research paper
#fig = plt.figure()
#plt.plot(loss_all[0:500,0], loss_all[0:500,1])
# Plotly version of the loss curve, using the first 500 rows of loss_all with
# column 0 on the x-axis and column 1 on the y-axis. This assumes loss_all is a
# 2-D array of (step, loss) rows as returned by aif_trainer.train(); TODO confirm.
fig = go.Figure(data=go.Scatter(x=loss_all[0:500,0], y=loss_all[0:500,1], name='Loss'))

# Write the figure to figure_6.html and render that file inline.
plotly.offline.plot(fig, filename = 'figure_6.html')
display(HTML('figure_6.html'))
# Evaluate the best model on the validation loader. Print the overall accuracy
# and loss, then the per-class accuracies returned by the trainer.
acc, loss, acc_class = aif_trainer.check_validation_test_accuracy(loader_for_val, best_model)
print(acc, loss)
print(acc_class)
3.6928033750882605e-06 2.142615556716919
[2.94117217e-06 4.44443458e-06]

Train with binary segmentation

# Data augmentation: random flips along the first and second image dimensions.
# `T` is the torchvision.transforms alias imported at the top of the notebook.
# The `import torchvision.<sub> as <alias>` statements there do not bind the
# bare name `torchvision`.
transform = T.Compose([RandomFlip1stDim(0.5), RandomFlip2ndDim(0.5)])
perf_aif_dataset.transform = transform

# Binary LV-only setup: a single output channel and one probability threshold.
perf_aif_dataset.which_mask = 'lv'
num_classes = 1
p_thres = 0.5
print(perf_aif_dataset.which_mask)

# Fetch one sample. sample[1] is presumably its mask (it is what gets plotted).
sample = perf_aif_dataset[1]

print(sample[1].shape)

#Original Plot from research paper
#plt.figure()
#plt.imshow(np.squeeze(sample[1]))
lv
(1, 64, 48)
# Interactive Plotly version of the mask plot above (sample[1], squeezed).
# `go`, `plotly`, `display` and `HTML` must already be in scope. `go` and
# `plotly` are presumably imported elsewhere in the notebook; TODO confirm.

#fig = px.imshow(np.squeeze(sample[1]), binary_string=True)

# Commented-out line, apparently left over from a Plotly example; not used.
#df = pd.read_csv("https://raw.githubusercontent.com/plotly/datasets/master/volcano.csv")

# Create an empty figure.
fig = go.Figure()

# Add the squeezed mask as a 2-D heatmap with a gray colorscale.
fig.add_trace(go.Heatmap(z=np.squeeze(sample[1]), colorscale="Gray"))



# Fixed 800x900 px size; the 100 px top margin leaves room for the dropdowns.
fig.update_layout(
    width=800,
    height=900,
    autosize=False,
    margin=dict(t=100, b=0, l=0, r=0),
)

# 3-D scene options. They only affect 3-D traces, so they have no visible
# effect on this 2-D heatmap.
fig.update_scenes(
    aspectratio=dict(x=1, y=1, z=0.7),
    aspectmode="manual"
)

# Two dropdown menus above the plot area (y > 1 in paper coordinates):
# - the first switches the heatmap colorscale;
# - the second toggles `reversescale`.
# "restyle" updates the named trace attribute in place.
button_layer_1_height = 1.08
fig.update_layout(
    updatemenus=[
        dict(
            buttons=list([
                dict(
                    args=["colorscale", "Gray"],
                    label="Gray",
                    method="restyle"
                ),
                dict(
                    args=["colorscale", "Cividis"],
                    label="Cividis",
                    method="restyle"
                ),
                dict(
                    args=["colorscale", "Blues"],
                    label="Blues",
                    method="restyle"
                ),
                dict(
                    args=["colorscale", "Viridis"],
                    label="Viridis",
                    method="restyle"
                ),
            ]),
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
        dict(
            buttons=list([
                dict(
                    args=["reversescale", False],
                    label="False",
                    method="restyle"
                ),
                dict(
                    args=["reversescale", True],
                    label="True",
                    method="restyle"
                )
            ]),
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.37,
            xanchor="left",
            y=button_layer_1_height,
            yanchor="top"
        ),
    ]
)

# Text labels placed next to the two dropdown menus.
fig.update_layout(
    annotations=[
        dict(text="colorscale", x=0, xref="paper", y=1.06, yref="paper",
                             align="left", showarrow=False),
        dict(text="Reverse<br>Colorscale", x=0.25, xref="paper", y=1.07,
                             yref="paper", showarrow=False)
    ])


# Write the figure to a standalone HTML file, then render that file inline.
# IPython's HTML() loads the file when given an existing path.
# NOTE(review): plotly.offline.plot opens the file in a browser by default
# (auto_open=True); pass auto_open=False if that is not wanted.
plotly.offline.plot(fig, filename = 'figure_7.html')
display(HTML('figure_7.html'))
# Pull one batch from the training iterator to check shapes and value ranges.
# Use the builtin next(): DataLoader iterators in recent PyTorch releases no
# longer provide a Python-2 style .next() method.
images, masks, names = next(iter_train)

B, C, RO, E1 = images.shape

print(images.shape)
print(masks.shape)
print(torch.max(images))
print(torch.max(masks))

# Channel 32 of every image in the batch (presumably one time frame), reshaped to
# (B, 1, RO, E1) so make_grid tiles it as single-channel images.
a = images[:,32,:,:]
a = torch.reshape(a, (B, 1, RO, E1))

# `show` is a notebook helper defined elsewhere.
plt.figure(figsize=(16, 16))
show(make_grid(a.double(), nrow=8, padding=2, normalize=False, scale_each=True))

# The masks of the same batch.
plt.figure(figsize=(16, 16))
show(make_grid(masks.double(), nrow=8, padding=2, normalize=True, scale_each=False))

X = images.type(torch.FloatTensor)
y = masks.type(torch.FloatTensor)
print(X.shape)
print(y.shape)
# Build the binary (single-output) ResUnet, optimizer, LR scheduler, loss and
# TensorBoard writer for the LV-only run. num_classes (1) is set by the binary
# augmentation cell earlier in the notebook.
num_epochs = 50
print_every = 100000

inplanes = 96
layers=[4, 4]
layers_planes=[128, 128]

# Input shape (C, H, W) of one AIF series; C is passed as the network's F0.
print(perf_aif_dataset.aif[0].shape)
C, H, W = perf_aif_dataset.aif[0].shape

model = models.GadgetronResUnet18(F0=C, 
                          inplanes=inplanes, 
                          layers=layers, 
                          layers_planes=layers_planes, 
                          use_dropout=False, 
                          p=0.5, 
                          H=H, W=W, C=num_classes,
                          verbose=True)

# Wrap the model in DataParallel when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print("model on multiple GPU ... ")

# ReduceLROnPlateau settings.
# NOTE(review): `factor` is not passed to the scheduler below, so the scheduler
# uses ReduceLROnPlateau's default factor. Pass factor=factor if 0.5 was intended.
patience = 10
factor = 0.5
cooldown = 3
min_lr = 1e-7

weight_decay=0
learning_rate = 1e-3

optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=weight_decay, amsgrad=False)

# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience, cooldown=cooldown, min_lr=min_lr, verbose=True)

# Binary segmentation loss (jaccard_weight presumably weights a Jaccard term).
criterion = training.LossBinary(jaccard_weight=0.5)
# criterion = training.LossMulti(class_weights=class_weights, jaccard_weight=0.5)

# NOTE(review): with the same lr and epochs, this is the same directory the
# multi-class run above writes to ('./aif_training/ResUnet_lr_..._epochs_...').
# The two runs' event files therefore end up mixed; consider a run-specific suffix.
log_dir = 'aif_training/ResUnet' + '_lr_' + str(learning_rate) + '_epochs_' + str(num_epochs)
writer = SummaryWriter(log_dir)
(64, 64, 48)
GadgetronResUnet : F0=64, inplanes=96
------------------------------------------------------------
    GadgetronResUnetInputBlock : input size (64, 64, 48), output size (96, 64, 48) --> (96, 64, 48)
------------------------------------------------------------
    GadgetronResUnet, down layer 0:
        GadgetronResUnet, down layer (64, 48) -> (32, 24)
        GadgetronResUnetBasicBlock : input size (96, 32, 24), output size (128, 32, 24) --> (128, 32, 24)
        GadgetronResUnetBasicBlock : input size (128, 32, 24), output size (128, 32, 24) --> (128, 32, 24)
        GadgetronResUnetBasicBlock : input size (128, 32, 24), output size (128, 32, 24) --> (128, 32, 24)
        GadgetronResUnetBasicBlock : input size (128, 32, 24), output size (128, 32, 24) --> (128, 32, 24)
    GadgetronResUnet, down layer 1:
        GadgetronResUnet, down layer (32, 24) -> (16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
------------------------------------------------------------
    GadgetronResUnet, bridge layer (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnet, down layer (16, 12) -> (8, 6)
        GadgetronResUnetBasicBlock : input size (128, 8, 6), output size (128, 8, 6) --> (128, 8, 6)
        GadgetronResUnetBasicBlock : input size (128, 8, 6), output size (128, 8, 6) --> (128, 8, 6)
        GadgetronResUnetBasicBlock : input size (128, 8, 6), output size (128, 8, 6) --> (128, 8, 6)
        GadgetronResUnetBasicBlock : input size (128, 8, 6), output size (128, 8, 6) --> (128, 8, 6)
------------------------------------------------------------
    GadgetronResUnet, up layer 0:
        GadgetronResUnet_UpSample : input size (256, 8, 6), upsampled size (256, 16, 12)
        GadgetronResUnetBasicBlock : input size (256, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
        GadgetronResUnetBasicBlock : input size (128, 16, 12), output size (128, 16, 12) --> (128, 16, 12)
    GadgetronResUnet, up layer 1:
        GadgetronResUnet_UpSample : input size (256, 16, 12), upsampled size (256, 32, 24)
        GadgetronResUnetBasicBlock : input size (256, 32, 24), output size (96, 32, 24) --> (96, 32, 24)
        GadgetronResUnetBasicBlock : input size (96, 32, 24), output size (96, 32, 24) --> (96, 32, 24)
        GadgetronResUnetBasicBlock : input size (96, 32, 24), output size (96, 32, 24) --> (96, 32, 24)
        GadgetronResUnetBasicBlock : input size (96, 32, 24), output size (96, 32, 24) --> (96, 32, 24)
    GadgetronResUnet, up layer 2:
        GadgetronResUnet_UpSample : input size (192, 32, 24), upsampled size (192, 64, 48)
        GadgetronResUnetBasicBlock : input size (192, 64, 48), output size (96, 64, 48) --> (96, 64, 48)
        GadgetronResUnetBasicBlock : input size (96, 64, 48), output size (96, 64, 48) --> (96, 64, 48)
        GadgetronResUnetBasicBlock : input size (96, 64, 48), output size (96, 64, 48) --> (96, 64, 48)
        GadgetronResUnetBasicBlock : input size (96, 64, 48), output size (96, 64, 48) --> (96, 64, 48)
------------------------------------------------------------
Output layer (96, 64, 48) --> (1, 64, 48)
------------------------------------------------------------
# Binary (LV-only) trainer: trains `model` on loader_for_train and validates on
# loader_for_val, using the single probability threshold p_thres from the
# augmentation cell. Targets are passed as float32 (y_dtype).
# NOTE(review): model_folder "aif_training/" is the same folder the multi-class
# trainer earlier in the notebook uses, so saved models may overwrite each
# other; confirm. early_stopping_thres=100 also exceeds epochs=50.
aif_trainer = training.GadgetronTwoClassSeg_PerfAIF(model, 
                                   optimizer, 
                                   criterion, 
                                   loader_for_train, 
                                   loader_for_val, 
                                   p_thres=p_thres, 
                                   scheduler=scheduler, 
                                   epochs=num_epochs, 
                                   device=device, 
                                   x_dtype=torch.float32, 
                                   y_dtype=torch.float32, 
                                   early_stopping_thres = 100,                              
                                   print_every=print_every, 
                                   small_data_mode = False, 
                                   writer=writer, 
                                   model_folder="aif_training/")
# train() returns five values, unpacked here.
# NOTE(review): the recorded run failed while train() created its tqdm_notebook
# progress bar (a ValueError from ipywidgets/traitlets). This looks like an
# environment mismatch rather than a problem in this cell.
epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)
50
Start training ... 
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
----------------------------------------
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-384-ebf2cf2d36fb> in <module>
----> 1 epochs_traning, epochs_validation, best_model, loss_all, epochs_acc_class = aif_trainer.train(verbose=True, epoch_to_load=-1, save_model_epoch=True)

~\Desktop\New folder (2) temp\deep_learning\aif_detection\training\training_base.py in train(self, verbose, epoch_to_load, save_model_epoch)
    113 
    114             random.seed()
--> 115             tq = tqdm_notebook(total=(len(self.loader_for_train) * batch_size), file=sys.stdout)
    116             tq.set_description('Epoch {}, total {}'.format(e, self.epochs))
    117 

~\Anaconda3\envs\qperf\lib\site-packages\tqdm\__init__.py in tqdm_notebook(*args, **kwargs)
     25          "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`",
     26          TqdmDeprecationWarning, stacklevel=2)
---> 27     return _tqdm_notebook(*args, **kwargs)
     28 
     29 

~\Anaconda3\envs\qperf\lib\site-packages\tqdm\notebook.py in __init__(self, *args, **kwargs)
    246         unit_scale = 1 if self.unit_scale is True else self.unit_scale or 1
    247         total = self.total * unit_scale if self.total else self.total
--> 248         self.container = self.status_printer(self.fp, total, self.desc, self.ncols)
    249         self.container.pbar = self
    250         if display_here:

~\Anaconda3\envs\qperf\lib\site-packages\tqdm\notebook.py in status_printer(_, total, desc, ncols)
    117                 "/user_install.html")
    118         if total:
--> 119             pbar = IProgress(min=0, max=total)
    120         else:  # No total? Show info style bar with no progress tqdm status
    121             pbar = IProgress(min=0, max=1)

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in __new__(cls, *args, **kwargs)
    956         else:
    957             inst = new_meth(cls, *args, **kwargs)
--> 958         inst.setup_instance(*args, **kwargs)
    959         return inst
    960 

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in setup_instance(self, *args, **kwargs)
    984         self._trait_notifiers = {}
    985         self._trait_validators = {}
--> 986         super(HasTraits, self).setup_instance(*args, **kwargs)
    987 
    988     def __init__(self, *args, **kwargs):

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in setup_instance(self, *args, **kwargs)
    975             else:
    976                 if isinstance(value, BaseDescriptor):
--> 977                     value.instance_init(self)
    978 
    979 

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in instance_init(self, obj)
   1690     def instance_init(self, obj):
   1691         self._resolve_classes()
-> 1692         super(Instance, self).instance_init(obj)
   1693 
   1694     def _resolve_classes(self):

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in instance_init(self, obj)
    518         # use provides a static default, transfer that to obj._trait_values.
    519         with obj.cross_validation_lock:
--> 520             if (self._dynamic_default_callable(obj) is None) \
    521                     and (self.default_value is not Undefined):
    522                 v = self._validate(obj, self.default_value)

~\AppData\Roaming\Python\Python37\site-packages\traitlets\traitlets.py in _dynamic_default_callable(self, obj)
    502             mro = type(obj).mro()
    503             meth_name = '_%s_default' % self.name
--> 504             for cls in mro[:mro.index(self.this_class) + 1]:
    505                 if hasattr(cls, '_trait_default_generators'):
    506                     default_handler = cls._trait_default_generators.get(self.name)

ValueError: <class 'ipywidgets.widgets.widget.Widget'> is not in list

Saving the model

# Move the best model to the CPU and strip any multi-GPU wrapper, so the saved
# file holds the bare network rather than a DataParallel/DDP container.
# (Replaces a bare `except:` that also swallowed KeyboardInterrupt and real errors.)
best_model_cpu = best_model.cpu()
if isinstance(best_model_cpu, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
    best_model_cpu = best_model_cpu.module

print(best_model_cpu)
GadgetronResUnet(
  (input_layer): GadgetronResUnetInputBlock(
    (conv1): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (down_layers): Sequential(
    (Down layer 0): Sequential(
      (downsample): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (ResBlock0): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
      (ResBlock1): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
    )
    (Down layer 1): Sequential(
      (downsample): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (ResBlock0): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
      (ResBlock1): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
      (ResBlock2): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
    )
  )
  (bridge_layer): Sequential(
    (downsample): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (ResBlock0): GadgetronResUnetBasicBlock(
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dp): Dropout2d(p=0.5, inplace=False)
    )
    (ResBlock1): GadgetronResUnetBasicBlock(
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dp): Dropout2d(p=0.5, inplace=False)
    )
    (ResBlock2): GadgetronResUnetBasicBlock(
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dp): Dropout2d(p=0.5, inplace=False)
    )
  )
  (up_layers): Sequential(
    (Up layer 0): GadgetronResUnet_UpSample(
      (blocks): Sequential(
        (upsample 0): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
        (upsample 1): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
        (upsample 2): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
      )
    )
    (Up layer 1): GadgetronResUnet_UpSample(
      (blocks): Sequential(
        (upsample 0): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(256, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
        (upsample 1): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
      )
    )
    (Up layer 2): GadgetronResUnet_UpSample(
      (blocks): Sequential(
        (upsample 0): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
        (upsample 1): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
      )
    )
  )
  (output_conv): Conv2d(96, 3, kernel_size=(1, 1), stride=(1, 1))
)
from datetime import date
from time import gmtime, strftime

# Today's calendar date in ISO form (e.g. "2021-01-28").
# NOTE: date.today() uses LOCAL time, while `moment` below is UTC, so near midnight
# the two can disagree on the day.
today = date.today().isoformat()
print(today)

# UTC timestamp (e.g. "20210128_012515") used to tag the saved model file name.
utc_now = gmtime()
moment = strftime("%Y%m%d_%H%M%S", utc_now)
print(moment)
2021-01-28
20210128_012515
# Output path of the serialized network:
#   ./deployment/networks/perf_aif_<which_mask>_network_<moment>.pbt
model_file = "".join([
    "./deployment/networks/perf_aif_",
    perf_aif_dataset.which_mask,
    "_network_",
    moment,
    ".pbt",
])
print(model_file)
./deployment/networks/perf_aif_lv_network_20210128_012515.pbt
import os

# Persist the whole (unwrapped, CPU-resident) network object, then reload it as a
# round-trip sanity check. torch.save pickles the module object, so the file can
# only be loaded where the defining classes are importable, and it must only be
# loaded from a trusted source (torch.load with weights_only=False unpickles
# arbitrary objects).
model_dir = os.path.dirname(model_file)
if model_dir:
    # torch.save does not create missing folders; a fresh checkout may lack them.
    os.makedirs(model_dir, exist_ok=True)
torch.save(best_model_cpu, model_file)

try:
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects a pickled full
    # nn.Module. We just wrote this file ourselves, so full unpickling is acceptable.
    model_loaded = torch.load(model_file, map_location="cpu", weights_only=False)
except TypeError:
    # Older PyTorch releases (< 1.13) do not accept the weights_only argument.
    model_loaded = torch.load(model_file, map_location="cpu")
print(model_loaded)
GadgetronResUnet(
  (input_layer): GadgetronResUnetInputBlock(
    (conv1): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (down_layers): Sequential(
    (Down layer 0): Sequential(
      (downsample): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (ResBlock0): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
      (ResBlock1): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
    )
    (Down layer 1): Sequential(
      (downsample): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (ResBlock0): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
      (ResBlock1): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
      (ResBlock2): GadgetronResUnetBasicBlock(
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dp): Dropout2d(p=0.5, inplace=False)
      )
    )
  )
  (bridge_layer): Sequential(
    (downsample): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (ResBlock0): GadgetronResUnetBasicBlock(
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dp): Dropout2d(p=0.5, inplace=False)
    )
    (ResBlock1): GadgetronResUnetBasicBlock(
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dp): Dropout2d(p=0.5, inplace=False)
    )
    (ResBlock2): GadgetronResUnetBasicBlock(
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dp): Dropout2d(p=0.5, inplace=False)
    )
  )
  (up_layers): Sequential(
    (Up layer 0): GadgetronResUnet_UpSample(
      (blocks): Sequential(
        (upsample 0): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
        (upsample 1): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
        (upsample 2): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
      )
    )
    (Up layer 1): GadgetronResUnet_UpSample(
      (blocks): Sequential(
        (upsample 0): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(256, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
        (upsample 1): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
      )
    )
    (Up layer 2): GadgetronResUnet_UpSample(
      (blocks): Sequential(
        (upsample 0): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
        (upsample 1): GadgetronResUnetBasicBlock(
          (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (dp): Dropout2d(p=0.5, inplace=False)
        )
      )
    )
  )
  (output_conv): Conv2d(96, 3, kernel_size=(1, 1), stride=(1, 1))
)